Conversation
Summary of Changes

Hello @zianglih, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the debugging capabilities for weight synchronization processes within the

Highlights
Code Review
This pull request introduces a valuable debug feature for weight synchronization and adds support for mxfp8 quantization. However, critical security vulnerabilities were identified, primarily related to insecure deserialization and command injection. Specifically, the use of torch.load() without weights_only=True in the new debugging utility could lead to arbitrary code execution from malicious checkpoints. Furthermore, several training scripts are vulnerable to command injection due to direct interpolation of user-supplied arguments into shell commands. Beyond these security concerns, suggestions were made to enhance code quality by refining exception handling, reducing code duplication, and cleaning up module exports. Addressing these security issues is paramount.
    )
    return safe_open(path, framework="pt", device="cpu")
if self.fmt == "bin":
    obj = torch.load(path, map_location="cpu")
The use of torch.load() without weights_only=True is insecure as it relies on the pickle module, which can execute arbitrary code during deserialization. An attacker could provide a malicious checkpoint file that, when loaded for debugging or comparison, executes arbitrary commands on the system. It is highly recommended to use weights_only=True to restrict deserialization to safe types.
-obj = torch.load(path, map_location="cpu")
+obj = torch.load(path, map_location="cpu", weights_only=True)
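The danger described above comes from pickle's `__reduce__` hook, which `torch.load`'s default (non-`weights_only`) path honors. A minimal stdlib-only sketch of the mechanism, with no torch dependency (`MaliciousPayload` is a hypothetical name for illustration):

```python
import pickle

class MaliciousPayload:
    # __reduce__ tells pickle what to call at load time; torch.load without
    # weights_only=True goes through this same pickle machinery, so a crafted
    # checkpoint can trigger arbitrary calls (os.system, eval, ...).
    def __reduce__(self):
        return (eval, ("6 * 7",))

blob = pickle.dumps(MaliciousPayload())
result = pickle.loads(blob)  # eval("6 * 7") runs during deserialization
print(result)  # 42
```

With `weights_only=True`, the unpickler is restricted to tensor and primitive types, and loading such a payload raises an error instead of executing it.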
U.exec_command(
    f"huggingface-cli download Qwen/{args.model_name}-FP8 --local-dir /root/models/{args.model_name}-FP8"
)
The model_name argument is directly interpolated into a shell command string without sanitization. This allows for command injection if an attacker can control the model_name parameter. For example, a model_name like ; touch /tmp/pwned would result in the execution of the injected command. Use shlex.quote() to sanitize any variables used in shell commands.
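A minimal sketch of the recommended fix, assuming the attacker-controlled value shown here is hypothetical:

```python
import shlex

# hypothetical attacker-controlled value
model_name = "Qwen3-4B; touch /tmp/pwned"

# Unsafe: the ';' splits the command when the string reaches a shell
unsafe = f"huggingface-cli download Qwen/{model_name}-FP8"

# Safe: shlex.quote() wraps the whole argument in single quotes so the
# shell treats it as a single token
safe = f"huggingface-cli download {shlex.quote(f'Qwen/{model_name}-FP8')}"
print(safe)
```

Splitting the two strings with `shlex.split` shows the difference: the quoted version stays a single argument, while the unquoted one breaks into extra tokens, including `touch`.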
scripts/run_qwen3_30b_a3b.py
Outdated
U.exec_command(
    f"python tools/convert_hf_to_mxfp8.py --model-dir /root/models/{args.model_name} --save-dir {mxfp8_path}"
)
scripts/run_qwen3_4b.py
Outdated
-if args.rollout_fp8:
+if args.rollout_fp8 and not use_blackwell_fp8:
     U.exec_command(f"hf download Qwen/{args.model_name}-FP8 --local-dir /root/models/{args.model_name}-FP8")
scripts/run_qwen3_4b.py
Outdated
U.exec_command(
    f"python tools/convert_hf_to_mxfp8.py --model-dir /root/models/{args.model_name} --save-dir {mxfp8_path}"
)
-__all__ = ["remove_padding", "quantize_param", "quantize_params_fp8", "quantize_params_compressed_tensors"]
+__all__ = [
+    "remove_padding",
+    "quantize_param",
The __all__ list includes quantize_param, but this function is not defined or imported in this module. This appears to be a pre-existing issue, but since this block is being modified, it's a good opportunity to correct it. Removing this line will prevent AttributeError exceptions on star imports (from module import *) and improve code clarity.
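A small self-contained sketch of why a stale `__all__` entry matters; the `demo_quant` module and its names are made up here to mirror the flagged `quantize_param` entry:

```python
import sys
import types

# Build a throwaway module whose __all__ names an attribute that was never
# defined, mirroring the stale "quantize_param" entry in the real module.
mod = types.ModuleType("demo_quant")
mod.__all__ = ["remove_padding", "quantize_param"]
mod.remove_padding = lambda x: x   # defined
sys.modules["demo_quant"] = mod    # quantize_param deliberately missing

star_import_error = None
try:
    exec("from demo_quant import *")
except AttributeError as exc:      # raised for the missing __all__ entry
    star_import_error = exc
print(star_import_error)
```

Plain `import demo_quant` would still succeed; only the star import trips over the dangling name, which is why such entries tend to go unnoticed.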
# experts
expert_pattern = r"mlp.experts\.(.+)\.weight(\d+)"
match = re.match(expert_pattern, rest)
if match:
    rest, expert_idx = match.groups()
    if rest in [
        "linear_fc1",
        "linear_fc2",
    ]:
        quantize_named_params = []
        for converted_name, param in converted_named_params:
            # skip bf16 weight_scale and input_scale
            # TODO: find a clearer way.
            if converted_name.endswith("_scale"):
                continue
            quantize_named_params.extend(_quantize_param(converted_name, param))

        return quantize_named_params

# shared expert
shared_expert_pattern = r"mlp.shared_experts\.(.+)"
match = re.match(shared_expert_pattern, rest)
if match:
    rest = match.groups()[0]
    if rest in [
        "linear_fc1.weight",
        "linear_fc2.weight",
    ]:
        quantize_named_params = []
        for converted_name, param in converted_named_params:
            quantize_named_params.extend(_quantize_param(converted_name, param))

        return quantize_named_params

if rest in [
    "self_attention.linear_proj.weight",
    "self_attention.linear_qkv.weight",
    "mlp.linear_fc1.weight",
    "mlp.linear_fc2.weight",
    # mla
    "self_attention.linear_q_proj.weight",
    "self_attention.linear_q_down_proj.weight",
    "self_attention.linear_q_up_proj.weight",
    "self_attention.linear_kv_down_proj.weight",
    "self_attention.linear_kv_up_proj.weight",
]:
    quantize_named_params = []
    for converted_name, param in converted_named_params:
        quantize_named_params.extend(_quantize_param(converted_name, param))

    return quantize_named_params
There is significant code duplication in how quantization is applied for different layer types (experts, shared experts, and other linear layers). The logic to iterate over converted_named_params and call _quantize_param is repeated.
This could be refactored into a helper function to improve maintainability and readability. For example:
def _apply_quantization(converted_named_params, skip_scales=False):
    quantized_params = []
    for name, param in converted_named_params:
        if skip_scales and name.endswith("_scale"):
            continue
        quantized_params.extend(_quantize_param(name, param))
    return quantized_params

# ... inside quantize_params_mxfp8, you can then determine if quantization is needed
# and call the helper, e.g.:
# if should_quantize:
#     return _apply_quantization(converted_named_params, skip_scales=is_expert_layer)

except Exception:  # pragma: no cover - optional dependency
    safe_open = None
Catching a broad Exception for an optional import can hide other unexpected errors. It's better to catch the specific ImportError that occurs when the optional dependency is not installed.
-except Exception:  # pragma: no cover - optional dependency
+except ImportError:  # pragma: no cover - optional dependency
     safe_open = None
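The optional-import pattern the suggestion targets can be sketched as follows; `definitely_not_installed_pkg` is a deliberately nonexistent placeholder, and note that `ModuleNotFoundError` is a subclass of `ImportError`, so the narrower clause still covers the missing-package case:

```python
try:
    # hypothetical optional dependency, absent in this environment
    from definitely_not_installed_pkg import safe_open  # type: ignore
except ImportError:  # also catches ModuleNotFoundError
    safe_open = None

# Callers branch on availability instead of crashing at import time
if safe_open is None:
    print("optional fast path unavailable, falling back")
```

A broad `except Exception` here would also swallow unrelated failures raised while the dependency initializes (bad config, C-extension ABI mismatches), masking real bugs as "dependency not installed".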
except Exception as exc:  # pragma: no cover - optional dependency
    logger.warning(
        "Cannot resolve HF repo id %s (huggingface_hub unavailable): %s",
        path_or_repo,
        exc,
    )
Similar to the previous comment, catching a broad Exception for an optional import can mask other issues. It's more precise to catch ImportError here as well.
-except Exception as exc:  # pragma: no cover - optional dependency
+except ImportError as exc:  # pragma: no cover - optional dependency
     logger.warning(
         "Cannot resolve HF repo id %s (huggingface_hub unavailable): %s",
         path_or_repo,
         exc,
     )
Thanks for the contribution. I am looking into this PR.
@HumansAnd
This PR conducts a per-tensor bit-exact check for the first weight sync.
Example: